import torch
import torch.nn as nn
import math
import numpy as np
import torch.nn.functional as F
from .utils import *

class GPN(nn.Module):
    """Graph propagation network with learned, attention-weighted propagation.

    ``forward`` scores every edge (i, j) from the edge list as
    ``exp(LeakyReLU([W h_i || W h_j] @ a))`` and row-normalizes the scores into
    a propagation matrix ``P``. ``poisson_conv`` then runs ``T`` steps of
    ``u <- P u + B``, where ``B`` is a centred one-hot source term on the
    training nodes. The iteration presumably follows Poisson learning; this
    is inferred from the method name ``poisson_conv`` -- TODO confirm against
    the reference implementation.

    Working tensors are allocated on the device/dtype of the inputs rather
    than via a hard-coded ``.cuda()``, so the module also runs on CPU.
    """

    def __init__(self, features, n_hidden, edge_indices_no_diag, idx_train, labels, leaky_rate, adj, dropout, T):
        """
        Args:
            features: node feature matrix. ``shape[0]`` is the number of nodes
                and ``shape[1]`` the input feature size. It is also fed to
                ``self.ft`` inside ``poisson_conv`` when ``adj[0]`` is falsy.
            n_hidden: output size of the projection ``W`` used for attention.
            edge_indices_no_diag: index tensor whose rows 0 and 1 hold the
                source and target node of each edge (presumably shape (2, E)
                without self-loops, per the name -- TODO confirm).
            idx_train: indices of the labelled training nodes.
            labels: per-node integer labels, indexable by ``idx_train``. The
                number of classes is ``max(labels) + 1``.
            leaky_rate: negative slope of the LeakyReLU in the attention score.
            adj: pair. ``adj[0]`` is stored as ``isadj``; when it is falsy, a
                feature term is injected during propagation. ``adj[1]`` is
                stored as ``adj`` and is not read by any method in this class.
            dropout: dropout probability applied right after the feature
                injection.
            T: number of propagation steps.
        """
        super(GPN, self).__init__()
        self.features = features
        self.in_features = features.shape[1]
        self.out_features = n_hidden
        # Builtin max() would iterate the tensor element by element in Python
        # and return a 0-d tensor; reduce in torch and store a plain int.
        self.num_classes = int(torch.as_tensor(labels).max()) + 1
        self.num_nodes = features.shape[0]

        self.edge_indices_no_diag = edge_indices_no_diag
        self.idx_train = idx_train
        self.train_labels = labels[idx_train]

        # Attention parameters: projection W and scoring vector a over [Wh_i || Wh_j].
        self.W = nn.Linear(self.in_features, self.out_features, bias=False)
        self.a = nn.Parameter(torch.Tensor(2 * self.out_features, 1))

        # Direct feature -> class map, injected in poisson_conv when not isadj.
        self.ft = nn.Linear(self.in_features, self.num_classes, bias=False)

        self.act = nn.LeakyReLU(leaky_rate)
        self.isadj = adj[0]
        self.adj = adj[1]
        self.dropout = dropout
        self.T = T
        self.tmp = []  # not read by any method in this class

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize ``W``, ``a`` and ``ft`` with ``glorot`` (from ``.utils``)."""
        glorot(self.W.weight)
        glorot(self.a)
        glorot(self.ft.weight)

    def forward(self, h):
        """Compute per-node class scores.

        Args:
            h: node features to project with ``W``. Presumably this is the
                same (num_nodes, in_features) matrix as ``self.features`` --
                TODO confirm at the call site.

        Returns:
            Tensor of shape (num_nodes, num_classes) from ``poisson_conv``.

        Side effect: caches the normalized attention matrix in
        ``self.A_ds_no_diag``.
        """
        Wh = self.W(h)
        self.A_ds_no_diag = self.calculate_adj(Wh)
        return self.poisson_conv()

    def poisson_conv(self):
        """Run ``self.T`` steps of ``u <- P u + B`` starting from ``u = 0``.

        ``P`` is the row-normalized attention matrix cached by ``forward``,
        converted to sparse COO. ``B`` is zero except on training nodes, where
        it holds the one-hot label minus the mean one-hot vector over the
        training nodes.

        When ``self.isadj`` is falsy, ``self.ft(self.features)`` is added once,
        after step ``T - 3``, followed by dropout. For ``T <= 3`` that step is
        never reached.

        Returns:
            Dense tensor of shape (num_nodes, num_classes).
        """
        A = self.A_ds_no_diag
        # Allocate on A's device/dtype instead of hard-coding .cuda(): this
        # works on CPU and keeps dtypes consistent for torch.sparse.mm.
        B = torch.zeros(self.num_nodes, self.num_classes, dtype=A.dtype, device=A.device)
        B[self.idx_train, self.train_labels] = 1
        # Centre the source term. B is still zero outside the training rows,
        # so its column sum is the per-class count of training labels.
        B[self.idx_train, :] = B[self.idx_train, :] - B.sum(0, keepdim=True) / len(self.idx_train)

        ut = torch.zeros_like(B)
        P = self.torch_sparse(A)
        step = 0
        while step < self.T:
            ut = torch.sparse.mm(P, ut) + B
            step = step + 1
            if (not self.isadj) and step == self.T - 3:
                ut = ut + self.ft(self.features)
                ut = F.dropout(ut, self.dropout, training=self.training)

        return ut

    def calculate_adj(self, Wh):
        """Build the dense, row-normalized attention matrix over the edge list.

        Each edge (i, j) gets weight ``exp(LeakyReLU([Wh_i || Wh_j] @ a))``.
        Rows with zero total weight get a self-loop of weight 1, so every row
        can be normalized to sum to one (barring overflow in ``exp``).

        Args:
            Wh: projected node features, one row per node.

        Returns:
            Dense (num_nodes, num_nodes) tensor on ``Wh``'s device and dtype.
        """
        src = self.edge_indices_no_diag[0, :]
        dst = self.edge_indices_no_diag[1, :]
        feat = torch.cat((Wh[src, :], Wh[dst, :]), 1)
        atten_coef = torch.exp(self.act(torch.mm(feat, self.a))).flatten()
        atten = torch.zeros(self.num_nodes, self.num_nodes, dtype=Wh.dtype, device=Wh.device)
        atten[src, dst] = atten_coef
        # Give empty rows a self-loop so the normalization never divides by zero.
        pos = torch.where(atten.sum(1) == 0)[0]
        atten[pos, pos] = 1
        return atten / atten.sum(1, keepdim=True)

    def torch_sparse(self, A):
        """Convert dense ``A`` into a COO sparse tensor holding its nonzeros.

        Values are gathered by indexing ``A``, which is differentiable, so the
        result stays connected to ``A`` in the autograd graph.
        """
        idx = torch.nonzero(A).T
        values = A[idx[0], idx[1]]
        return torch.sparse_coo_tensor(idx, values, A.shape, device=A.device)